package org.apache.spark.sql.execution.datasources.restjson.demo1

import java.net.URL
import java.util.function.Consumer

import com.google.gson.JsonParser
import org.apache.http.client.fluent.Request
import org.apache.http.util.EntityUtils
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
import org.apache.spark.sql.{Row, SQLContext, SparkSession}

import scala.collection.mutable.ArrayBuffer

trait JsonProcessor {
  // Adapts a Scala function literal to java.util.function.Consumer so that
  // Java iterators/collections can be traversed with `forEach` from Scala.
  implicit def toConsumer[A](function: A => Unit): Consumer[A] =
    new Consumer[A] {
      override def accept(arg: A): Unit = function(arg)
    }

  /**
    * Default parsing strategy: wrap the entire raw response body in a single
    * one-column [[Row]]. Subclasses override this to extract structured fields.
    *
    * @param content raw HTTP response body
    * @return buffer holding exactly one Row containing `content`
    */
  def parse(content: String): ArrayBuffer[Row] =
    ArrayBuffer(Row.fromSeq(Seq(content)))
}


//abstract class DefaultSource extends RelationProvider with DataSourceRegister with JsonTextParser with Logging
//  with Serializable {
class DefaultSource extends SchemaRelationProvider with DataSourceRegister with JsonProcessor with Logging
  with Serializable {

  /**
    * Entry point used by Spark's data source resolution (SchemaRelationProvider):
    * the caller must supply the schema explicitly via `DataFrameReader.schema(...)`.
    *
    * @param parameters options passed through `DataFrameReader.options`; the
    *                   "path" key carries the REST endpoint URL
    * @param schema     user-provided schema applied to the parsed rows
    */
  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String],
                              schema: StructType): BaseRelation = {
    // The REST endpoint URL is passed in via the "path" option.
    val path = parameters.getOrElse("path", "")
    new RestJSONRelation(path, Some(schema))(sqlContext)
  }

  // Short alias usable as spark.read.format("restJson")
  override def shortName(): String = "restJson"

  private[sql] class RestJSONRelation(val path: String, val structType: Option[StructType])(@transient val sqlContext: SQLContext)
    extends BaseRelation with TableScan {

    // Safe here: SchemaRelationProvider guarantees a schema is always supplied
    // (createRelation wraps it in Some(...)).
    override def schema: StructType = structType.get

    /**
      * Fetches the first URL, parses the body into rows and wraps them in an RDD.
      *
      * BUG FIX: the original called EntityUtils.toString(response.getEntity)
      * BEFORE the `response != null` check, so the guard could never take effect
      * (a null response would NPE first) and the entity was consumed even for
      * non-200 responses. The entity is now read only after the response is
      * known to be non-null with HTTP 200.
      */
    private def createBaseRdd(inputPaths: Array[String]): RDD[Row] = {
      val url = inputPaths.head
      val response = Request.Get(new URL(url).toURI).execute().returnResponse()
      if (response != null && response.getStatusLine.getStatusCode == 200) {
        val content = EntityUtils.toString(response.getEntity)
        sqlContext.sparkContext.makeRDD(parse(content))
      } else {
        // Best-effort: an unreachable endpoint or error status yields an empty RDD.
        logWarning(s"Request to $url failed: " +
          Option(response).map(_.getStatusLine.toString).getOrElse("no response"))
        sqlContext.sparkContext.makeRDD(Seq.empty[Row])
      }
    }

    override def buildScan(): RDD[Row] = createBaseRdd(Array(path))
  }

}


class MyDataSource extends DefaultSource {
  /**
    * Extracts the "files" array from the response JSON and maps each element
    * to a (name: String, directory: Boolean) row.
    *
    * NOTE(review): the field names are coupled to this specific API's response
    * shape; could be generalized (e.g. driven by options) later.
    */
  override def parse(content: String): ArrayBuffer[Row] = {
    val rows = ArrayBuffer[Row]()
    val files = new JsonParser().parse(content).getAsJsonObject.get("files").getAsJsonArray
    files.forEach(toConsumer { element =>
      val file = element.getAsJsonObject
      rows += Row.fromSeq(Seq(file.get("name").getAsString, file.get("directory").getAsBoolean))
    })
    rows
  }
}

/**
  * Created by peibin on 2017/8/4.
  * Demo of the custom Spark SQL data source API defined in this file.
  */
object RestSourceDemo {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .appName("RestSource")
      .master("local[*]")
      .getOrCreate()

    // BUG FIX: the original hard-coded "...restjson.demo2.MyDataSource", but this
    // file's package is demo1 where MyDataSource actually lives. Using classOf
    // keeps the format string in sync with the real class. (Resolution details:
    // see DataSource.lookupDataSource.)
    val df = spark.read
      .format(classOf[MyDataSource].getName)
      .options(Map(
        "path" -> "http://10.199.212.80:8400/esperhqapp/hqapi/vfs/file"
      ))
      .schema(StructType(Seq(
        StructField("name2", StringType, true),
        StructField("directory2", BooleanType, true))))
      .load()
    // Alternatively, via shortName/DataSourceRegister:
    //   spark.read.format("restJson").load("http://10.199.212.80:8400/esperhqapp/hqapi/vfs/file")
    df.show()
  }

}
